package spark
import scala.collection.{mutable, Map}
import org.apache.spark.graphx._

/**
 * Modified Label Propagation (LPA) algorithm over a weighted graph.
 *
 * Written in Scala because the GraphX RDD API is not available from Python.
 * Differences from the stock GraphX implementation:
 *   1. Vertex labels are arbitrary Strings (e.g. known-bad tags) instead of
 *      the VertexId itself, so existing "bad" labels can mark other vertices.
 *   2. Vertices listed in `keepLabel` are pinned: they emit their label but
 *      never receive messages, so other labels can never overwrite them.
 *   3. Edge weights are accumulated into the label score; the original
 *      algorithm treats every edge as weight 1.
 */
object LabelPropagation_ extends Serializable {

  /**
   * Runs label propagation for up to `maxSteps` Pregel supersteps.
   *
   * @param graph     input graph; the vertex attribute is the initial label,
   *                  the edge attribute is the edge weight
   * @param maxSteps  maximum number of supersteps; must be > 0
   * @param keepLabel vertex ids whose labels are fixed for the whole run
   * @return the graph with the final label as each vertex attribute
   */
  def run(graph: Graph[String, Double], maxSteps: Int, keepLabel: Array[VertexId]): Graph[String, Double] = {
    require(maxSteps > 0, s"Maximum of steps must be greater than 0, but got ${maxSteps}")

    // Array.contains is an O(n) scan and runs (up to twice) per edge per
    // superstep; a Set gives O(1) membership tests.
    val pinned = keepLabel.toSet

    // A message maps each candidate label to the total edge weight backing it.
    def sendMessage(e: EdgeTriplet[String, Double]): Iterator[(VertexId, Map[String, Double])] = {
      val srcPinned = pinned.contains(e.srcId)
      val dstPinned = pinned.contains(e.dstId)
      if (srcPinned && dstPinned) {
        // Both endpoints are fixed: neither can change, so send nothing.
        Iterator.empty
      } else if (srcPinned) {
        // Pinned vertices emit their label but must never receive messages.
        Iterator((e.dstId, Map(e.srcAttr -> e.attr)))
      } else if (dstPinned) {
        Iterator((e.srcId, Map(e.dstAttr -> e.attr)))
      } else {
        Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
      }
    }

    // Sum, per label, the supporting weight from the two partial tallies.
    // Mimics the optimization of breakOut, not present in Scala 2.13, while
    // working in 2.12.
    def mergeMessage(count1: Map[String, Double], count2: Map[String, Double]): Map[String, Double] = {
      val merged = mutable.Map[String, Double]()
      (count1.keySet ++ count2.keySet).foreach { label =>
        merged.put(label, count1.getOrElse(label, 0.0) + count2.getOrElse(label, 0.0))
      }
      merged
    }

    // Adopt the heaviest label; on the empty initial message keep the current one.
    // NOTE(review): maxBy tie-breaking between equal weights is arbitrary, so
    // results may vary between runs when ties occur.
    def vertexProgram(vid: VertexId, attr: String, message: Map[String, Double]): String =
      if (message.isEmpty) attr else message.maxBy(_._2)._1

    val initialMessage = Map[String, Double]()
    // The original identity mapVertices was a no-op; run Pregel on the input directly.
    Pregel(graph, initialMessage, maxIterations = maxSteps)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)
  }
}